"""
Visualization functions for structural break analysis.

Generates:
1. Rolling-window coefficient evolution plot
2. Split-sample coefficient comparison
3. Pre/post-tariff residual tracking for the US, China, and other major trading partners
4. Heatmaps of post-break interaction coefficients and their p-values
"""

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pathlib import Path

# Root directory for every generated artefact; figures and tables are written
# to fixed sub-directories beneath it.
OUTPUT_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/output")
FIG_DIR = OUTPUT_DIR.joinpath("figures")
TAB_DIR = OUTPUT_DIR.joinpath("tables")


# (year, label, linestyle) for each regime break marked on the rolling plots.
_REGIME_BREAKS = (
    (2001, 'WTO', '--'),
    (2008, 'GFC', ':'),
    (2018, 'Tariffs', '-.'),
)


def _mark_regime_breaks(ax, annotate=True):
    """Draw a gray vertical line at each regime break, optionally labelled."""
    for year, label, ls in _REGIME_BREAKS:
        ax.axvline(x=year, color='gray', linewidth=0.8, linestyle=ls,
                   alpha=0.5)
        if annotate:
            # x in data units, y in axes fraction: the label always sits at
            # 95% of the panel height, whatever the sign/scale of the y-limits.
            ax.text(year, 0.95, label, transform=ax.get_xaxis_transform(),
                    fontsize=7, ha='center', color='gray')


def plot_rolling_coefficients(rolling_df, filename="rolling_coefficients.png",
                              window_years=15):
    """
    Plot Z_1, Z_2, Z_3 coefficients over rolling windows with confidence bands.

    Produces a 2x2 grid: one panel per Z coefficient (line plus a
    +/-1.96*SE band labelled "95% CI") and a fourth panel with the model R².
    Regime breaks (WTO 2001, GFC 2008, Tariffs 2018) are marked on every
    panel.

    Parameters
    ----------
    rolling_df : DataFrame-like
        One row per window. Columns read: 'window_mid', 'r_squared', and
        '{Z}_coef' / '{Z}_se' for Z in Z_1, Z_2, Z_3.
    filename : str
        Output file name, written under FIG_DIR (created if missing).
    window_years : int
        Window length shown in the figure title (display only; defaults to
        the previously hard-coded 15).
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        x = rolling_df['window_mid']

        # Z coefficients with +/-1.96 SE bands
        for z, ax, color in [
            ('Z_1', axes[0, 0], '#1f77b4'),
            ('Z_2', axes[0, 1], '#2ca02c'),
            ('Z_3', axes[1, 0], '#d62728'),
        ]:
            coef = rolling_df[f'{z}_coef']
            se = rolling_df[f'{z}_se']

            ax.plot(x, coef, color=color, linewidth=2, label=f'{z} coefficient')
            ax.fill_between(x, coef - 1.96 * se, coef + 1.96 * se,
                            alpha=0.2, color=color, label='95% CI')
            ax.axhline(y=0, color='black', linewidth=0.8, linestyle='-')
            _mark_regime_breaks(ax)

            ax.set_title(f'{z} Coefficient Over Time', fontsize=12)
            ax.set_xlabel('Window midpoint', fontsize=10)
            ax.set_ylabel('Coefficient', fontsize=10)
            ax.legend(fontsize=8)
            ax.grid(alpha=0.3)

        # R-squared evolution
        ax = axes[1, 1]
        ax.plot(x, rolling_df['r_squared'], color='purple', linewidth=2)
        ax.fill_between(x, 0, rolling_df['r_squared'], alpha=0.1, color='purple')
        _mark_regime_breaks(ax, annotate=False)
        ax.set_title('Model R² Over Time', fontsize=12)
        ax.set_xlabel('Window midpoint', fontsize=10)
        ax.set_ylabel('R²', fontsize=10)
        # Only pin the y-range when there is a finite, positive maximum;
        # NaN or all-zero R² would otherwise give an invalid/degenerate range.
        r2 = np.asarray(rolling_df['r_squared'], dtype=float)
        r2 = r2[np.isfinite(r2)]
        if r2.size and r2.max() > 0:
            ax.set_ylim(0, r2.max() * 1.3)
        ax.grid(alpha=0.3)

        fig.suptitle(f'Rolling-Window Estimation ({window_years}-year windows)',
                     fontsize=14, y=1.02)
        fig.tight_layout()
        FIG_DIR.mkdir(parents=True, exist_ok=True)
        fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"  Saved {filename}")


def plot_split_sample_comparison(split_results, filename="split_sample_comparison.png"):
    """
    Bar chart comparing Z coefficients pre vs post for each break.

    One panel per break (WTO 2001, GFC 2008, Tariffs 2018). Each panel shows
    pre- and post-break coefficients side by side, with +/-1.96*SE error bars
    and an R²/n annotation.

    Parameters
    ----------
    split_results : dict
        Keyed by break label ('wto', 'gfc', 'tariff'). Each value is a dict
        with optional 'pre' / 'post' entries. Each entry is a 2-tuple whose
        first item is a fitted model exposing ``beta``, ``se``, ``r_squared``
        and ``n_obs``; the second item is ignored here.
        - A break absent from the dict gets a hidden panel.
        - A break with a missing (or None) pre/post entry gets an
          "Insufficient data" panel.
    filename : str
        Output file name, written under FIG_DIR (created if missing).

    Notes
    -----
    Coefficients are read positionally: ``beta[i]`` / ``se[i]`` are taken as
    ``z_vars[i]``. NOTE(review): assumes the split regressions put Z_1..Z_3
    first, before any intercept or controls. Confirm against the estimator.
    """
    breaks = [
        ('wto', 'WTO (2001)'),
        ('gfc', 'GFC (2008)'),
        ('tariff', 'Tariffs (2018)'),
    ]
    z_vars = ['Z_1', 'Z_2', 'Z_3']
    n_z = len(z_vars)
    bar_width = 0.35
    x = np.arange(n_z)

    fig, axes = plt.subplots(1, len(breaks), figsize=(15, 5), sharey=False)
    try:
        for ax, (label, title) in zip(axes, breaks):
            if label not in split_results:
                ax.set_visible(False)
                continue

            # `or` also treats an entry stored as None like a missing one,
            # instead of crashing on tuple unpacking.
            pre_model, _ = split_results[label].get('pre') or (None, None)
            post_model, _ = split_results[label].get('post') or (None, None)

            if pre_model is None or post_model is None:
                ax.text(0.5, 0.5, 'Insufficient data', transform=ax.transAxes,
                        ha='center', va='center')
                ax.set_title(title)
                continue

            pre_coefs = [pre_model.beta[i] for i in range(n_z)]
            post_coefs = [post_model.beta[i] for i in range(n_z)]
            pre_se = [pre_model.se[i] for i in range(n_z)]
            post_se = [post_model.se[i] for i in range(n_z)]

            ax.bar(x - bar_width / 2, pre_coefs, bar_width,
                   yerr=[1.96 * s for s in pre_se],
                   label='Pre', color='#1f77b4', alpha=0.7,
                   capsize=3, edgecolor='navy')
            ax.bar(x + bar_width / 2, post_coefs, bar_width,
                   yerr=[1.96 * s for s in post_se],
                   label='Post', color='#ff7f0e', alpha=0.7,
                   capsize=3, edgecolor='darkorange')

            ax.axhline(y=0, color='black', linewidth=0.8)
            ax.set_xticks(x)
            ax.set_xticklabels(z_vars, fontsize=10)
            ax.set_title(title, fontsize=12)
            ax.legend(fontsize=9)
            ax.grid(axis='y', alpha=0.3)

            # Add R² annotation
            ax.text(0.02, 0.98,
                    f"Pre R²={pre_model.r_squared:.3f} (n={pre_model.n_obs})\n"
                    f"Post R²={post_model.r_squared:.3f} (n={post_model.n_obs})",
                    transform=ax.transAxes, fontsize=8, va='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        # Label the left-most *visible* panel. axes[0] may have been hidden
        # above, which would also hide its y-label.
        visible = [ax for ax in axes if ax.get_visible()]
        if visible:
            visible[0].set_ylabel('Coefficient', fontsize=11)

        fig.suptitle('Pre vs. Post Structural Break: Demographic Coefficients',
                     fontsize=14, y=1.02)
        fig.tight_layout()
        FIG_DIR.mkdir(parents=True, exist_ok=True)
        fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"  Saved {filename}")


def plot_tariff_residuals(tariff_yearly, tariff_year=2018,
                          filename="tariff_residuals.png", countries=None):
    """
    Plot year-by-year residuals for US, China, and other affected countries.

    One bar panel per country, laid out in up to three columns.
    - Bars before ``tariff_year`` are blue; bars from ``tariff_year`` on are red.
    - A dashed red line sits half a year before ``tariff_year``.
    - When both periods have data, the post-minus-pre mean residual shift is
      annotated.

    Parameters
    ----------
    tariff_yearly : mapping
        ISO3 code -> DataFrame with 'year' and 'resid_baseline' columns.
    tariff_year : int
        First post-tariff year. Also used in the figure title.
    filename : str
        Output file name, written under FIG_DIR (created if missing).
    countries : iterable of str, optional
        ISO3 codes to plot, in panel order. Defaults to
        USA, CHN, DEU, JPN, KOR, MEX, CAN. Codes missing from
        ``tariff_yearly`` are skipped.
    """
    if countries is None:
        focus = ['USA', 'CHN', 'DEU', 'JPN', 'KOR', 'MEX', 'CAN']
    else:
        focus = list(countries)
    available = [c for c in focus if c in tariff_yearly]

    if not available:
        print("  No tariff residual data to plot")
        return

    n = len(available)
    n_cols = min(3, n)
    n_rows = (n + n_cols - 1) // n_cols  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows),
                             squeeze=False)
    try:
        flat_axes = axes.ravel()  # row-major, matching the panel order
        for ax, iso3 in zip(flat_axes, available):
            cdf = tariff_yearly[iso3]

            # Color bars by pre/post
            colors = ['#1f77b4' if y < tariff_year else '#d62728'
                      for y in cdf['year']]

            ax.bar(cdf['year'], cdf['resid_baseline'], color=colors, alpha=0.7,
                   edgecolor='none')
            ax.axhline(y=0, color='black', linewidth=0.8)
            ax.axvline(x=tariff_year - 0.5, color='red', linewidth=1.5,
                       linestyle='--', alpha=0.7, label='Tariffs')

            # Annotate mean shift
            pre = cdf[cdf['year'] < tariff_year]['resid_baseline']
            post = cdf[cdf['year'] >= tariff_year]['resid_baseline']
            if len(pre) > 0 and len(post) > 0:
                shift = post.mean() - pre.mean()
                ax.text(0.98, 0.02,
                        f"Shift: {shift:+.2f} pp",
                        transform=ax.transAxes, fontsize=9, ha='right',
                        va='bottom',
                        bbox=dict(boxstyle='round', facecolor='lightyellow',
                                  alpha=0.8))

            ax.set_title(iso3, fontsize=12, fontweight='bold')
            ax.set_ylabel('Residual (pp CA/GDP)', fontsize=9)
            # Show the 'Tariffs' marker label (previously set but never shown).
            ax.legend(fontsize=8, loc='upper left')
            ax.grid(axis='y', alpha=0.3)

        # Hide unused grid cells
        for ax in flat_axes[n:]:
            ax.set_visible(False)

        # Title follows tariff_year instead of a hard-coded 2018.
        fig.suptitle(f'Model Residuals Before & After Tariffs ({tariff_year})',
                     fontsize=14, y=1.02)
        fig.tight_layout()
        FIG_DIR.mkdir(parents=True, exist_ok=True)
        fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"  Saved {filename}")


def _significance_stars(p):
    """Return '***', '**', '*' or '' for p < 0.01 / 0.05 / 0.1.

    NaN compares False against every threshold and therefore gets ''.
    """
    if p < 0.01:
        return '***'
    if p < 0.05:
        return '**'
    if p < 0.1:
        return '*'
    return ''


def _symmetric_limit(values):
    """Return the largest finite |value| in *values*, for a zero-centred scale.

    Falls back to 1.0 when there is no finite value or all values are zero.
    This keeps the colour normalisation from receiving NaN or a degenerate
    (0, 0) range.
    """
    magnitudes = np.abs(np.asarray(values, dtype=float))
    magnitudes = magnitudes[np.isfinite(magnitudes)]
    if magnitudes.size == 0:
        return 1.0
    limit = float(magnitudes.max())
    return limit if limit > 0 else 1.0


def _label_heatmap_axes(ax, pivot, title):
    """Use the pivot's columns/index as heatmap tick labels and set *title*."""
    ax.set_xticks(range(len(pivot.columns)))
    ax.set_xticklabels(pivot.columns)
    ax.set_yticks(range(len(pivot.index)))
    ax.set_yticklabels(pivot.index)
    ax.set_title(title)


def plot_break_interaction_significance(break_models,
                                        filename="break_interaction_pvalues.png"):
    """
    Heatmap showing p-values of Z×post_break interaction terms
    across all break tests.

    Draws two side-by-side panels:
    - Interaction coefficients, on a diverging colormap symmetric about zero.
    - Their p-values, with the colour scale spanning 0-0.2 and cells
      annotated with significance stars.

    Parameters
    ----------
    break_models : dict
        Break label -> 2-tuple whose first item is a fitted model, or None to
        skip that break. Models expose ``feature_names`` (supporting
        ``in`` / ``.index``) plus ``beta`` and ``pvalues`` in the same order.
        Terms named '{Z}_x_post' are read for Z in Z_1, Z_2, Z_3.
    filename : str
        Output file name, written under FIG_DIR (created if missing).

    Break/variable pairs without an interaction term appear as 'n/a' cells
    and are kept out of the colour-scale computation.
    """
    rows = []
    for label, (model, _) in break_models.items():
        if model is None:
            continue
        names = model.feature_names
        for z in ['Z_1', 'Z_2', 'Z_3']:
            int_name = f'{z}_x_post'
            if int_name in names:
                idx = names.index(int_name)
                rows.append({
                    'Break': label.upper(),
                    'Variable': z,
                    'Coefficient': model.beta[idx],
                    'p-value': model.pvalues[idx],
                })

    if not rows:
        print("  No interaction data for heatmap")
        return

    df = pd.DataFrame(rows)
    # pivot() sorts rows alphabetically. Reindex to keep the order in which
    # breaks appear in break_models. Missing Break x Variable pairs become NaN.
    break_order = list(dict.fromkeys(df['Break']))
    pivot_coef = (df.pivot(index='Break', columns='Variable', values='Coefficient')
                  .reindex(break_order))
    pivot_pval = (df.pivot(index='Break', columns='Variable', values='p-value')
                  .reindex(break_order))
    coef = pivot_coef.to_numpy(dtype=float)
    pval = pivot_pval.to_numpy(dtype=float)
    limit = _symmetric_limit(coef)  # NaN-safe, unlike min()/max() on coef

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    try:
        # Coefficient heatmap
        im1 = ax1.imshow(coef, cmap='RdBu_r', aspect='auto',
                         vmin=-limit, vmax=limit)
        _label_heatmap_axes(ax1, pivot_coef, 'Post-Break Interaction Coefficients')
        for (i, j), val in np.ndenumerate(coef):
            ax1.text(j, i, f"{val:.2f}" if np.isfinite(val) else 'n/a',
                     ha='center', va='center', fontsize=10)
        fig.colorbar(im1, ax=ax1)

        # P-value heatmap: values above 0.2 saturate at the top of the scale.
        im2 = ax2.imshow(pval, cmap='RdYlGn_r', aspect='auto',
                         vmin=0, vmax=0.2)
        _label_heatmap_axes(ax2, pivot_pval, 'P-values (green = significant)')
        for (i, j), val in np.ndenumerate(pval):
            if np.isfinite(val):
                text = f"{val:.3f}{_significance_stars(val)}"
            else:
                text = 'n/a'
            ax2.text(j, i, text, ha='center', va='center', fontsize=10)
        fig.colorbar(im2, ax=ax2)

        fig.suptitle('Structural Break Interaction Tests', fontsize=14, y=1.02)
        fig.tight_layout()
        FIG_DIR.mkdir(parents=True, exist_ok=True)
        fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"  Saved {filename}")


def generate_all_break_figures(results):
    """
    Render each structural-break figure whose input is present in *results*.

    Keys consulted (all optional):
      'rolling'          -> plot_rolling_coefficients, skipped when empty
      'split_results'    -> plot_split_sample_comparison
      'tariff_residuals' -> 2-item pair; its second item is passed to
                            plot_tariff_residuals
      'break_models'     -> plot_break_interaction_significance
    """
    print("\n>>> Generating structural break figures <<<")

    if 'rolling' in results:
        rolling_frame = results['rolling']
        if len(rolling_frame) > 0:
            plot_rolling_coefficients(rolling_frame)

    if 'split_results' in results:
        split = results['split_results']
        plot_split_sample_comparison(split)

    if 'tariff_residuals' in results:
        _ignored, per_country = results['tariff_residuals']
        plot_tariff_residuals(per_country)

    if 'break_models' in results:
        models = results['break_models']
        plot_break_interaction_significance(models)

    print("\n  All structural break figures generated.")
